Deep Dive: Simulating Morphologically Accurate Neurons with CRANTpy

Deep Dive: Simulating Morphologically Accurate Neurons with CRANTpy

This tutorial walks you through simulating morphologically accurate neurons with the CRANTpy library. We cover installing the necessary packages, loading neuron morphology data, setting up the simulation environment (a NEURON compartment model with passive membrane properties and Hodgkin–Huxley channels in the axon), running simulations driven by synaptic input, and visualizing the resulting membrane voltages.

# Import CRANTpy and other necessary libraries
import crantpy as cp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import navis

# Set up logging to see progress
# (at INFO level CRANTpy reports progress messages, such as the skeleton
# fetch and synapse attachment logged further below)
cp.set_logging_level("INFO")

# Confirm the import worked and show the default dataset name
# (presumably the dataset used when none is passed explicitly).
print("CRANTpy loaded successfully!")
print(f"Default dataset: {cp.CRANT_DEFAULT_DATASET}")
CRANTpy loaded successfully!
Default dataset: latest

1. Authentication Setup

Before we can access the data, we need to authenticate with the CAVE service. This is typically a one-time setup.

# One-time setup: uncomment the next line to create a CAVE token and save it.
# cp.generate_cave_token(save=True)

# Verify the stored credentials by opening a CAVE client and reporting the
# datastack and server it talks to. On any failure, print the error together
# with a hint on how to create a token.
try:
    client = cp.get_cave_client()
    print(f"Successfully connected to datastack: {client.datastack_name}")
    print(f"Server: {client.server_address}")
except Exception as exc:
    print(f"Connection failed: {exc}")
    print("Please run: cp.generate_cave_token(save=True)")
Successfully connected to datastack: kronauer_ant
Server: https://proofreading.zetta.ai

NEURON Installation

To simulate neurons, we need to install the Python interface to NEURON, a popular simulation environment for modeling individual neurons and networks of neurons.

# Install NEURON's Python package quietly (IPython shell escape; only needed
# if NEURON is not installed yet).
!pip install neuron -q

import neuron
# navis' NEURON interface provides the CompartmentModel used below.
import navis.interfaces.neuron as nrn
# Pick a sample olfactory projection neuron (PN): the first root ID matching
# the cell class, the right side and the 'mALT' tract.
opn_criteria = cp.NeuronCriteria(
    cell_class="olfactory_projection_neuron",
    side="right",
    tract="mALT",
)
opn = opn_criteria.get_roots()[0]

# Fetch its L2 skeleton, re-root it at the soma and record that root as the
# soma node, then attach its (cleaned) synapses to the skeleton.
# NOTE(review): max_distance=1000 presumably limits how far (in nm, before the
# later rescaling) a synapse may lie from its node - confirm in CRANTpy docs.
n = cp.reroot_at_soma(cp.get_l2_skeleton(opn))
n.soma = n.root
n = cp.attach_synapses(n, clean=True, max_distance=1000)
2025-10-07 18:08:24 - INFO - Fetching L2 skeleton for root ID: 576460752773799604
2025-10-07 18:08:36 - INFO - Attached 6592 connectors to neuron 576460752773799604
# Label nodes and connectors with their compartment ("axon", "dendrite",
# "linker", ...) instead of splitting the neuron into separate fragments.
# NOTE(review): cellbodyfiber="soma" presumably tells navis how to look for a
# cell body fiber - confirm against the navis docs.
n = navis.split_axon_dendrite(n, label_only=True, cellbodyfiber="soma")

# change unit to microns from nm
# (coordinates, and hence geodesic distances computed below, are in microns from here on)
n = n * 1e-3
# Turn the micron-scaled skeleton into a NEURON compartment model.
# NOTE(review): `res` presumably sets the target segment length in microns -
# confirm in the navis CompartmentModel docs.
cmp = nrn.CompartmentModel(n, res=10)

# Passive cable parameters for the whole neuron.
cmp.Ra = 266.1  # specific axial resistivity, Ohm*cm
cmp.cm = 0.8  # specific membrane capacitance, uF/cm**2 (NEURON's unit)

# Passive leak everywhere: leak conductance g = 1 / Rm, where Rm = 20800
# Ohm*cm**2 is the specific membrane resistance; e is the leak reversal
# potential in mV.
cmp.insert("pas", g=1 / 20800, e=-60)
import matplotlib.pyplot as plt

# Plot the skeleton coloured by the "compartment" labels that
# navis.split_axon_dendrite(..., label_only=True) added to the node table.
# The neuron in this notebook is `n` (there is no `skel` variable).
fig, ax = navis.plot2d(
    n, color_by="compartment", palette="tab10", lw=1.5
)

# Label compartments: write each compartment's name next to the last of its
# nodes in the node table (the x offset of 10 is in microns after rescaling).
for label in n.nodes.compartment.unique():
    loc = n.nodes.loc[n.nodes.compartment == label, ["x", "y", "z"]].values[-1]
    ax.text(loc[0] + 10, loc[1], label)

plt.tight_layout()
../_images/cd8b1aa25236b4c38e08d86ad19551490d8d3ba3d4fa66c00a938e2b84ef0677.png
# Nodes labelled "axon" or "linker" receive active (Hodgkin-Huxley) membrane.
axon_nodes = n.nodes.loc[
    n.nodes["compartment"].isin(["axon", "linker"]), "node_id"
].values

# Map those nodes onto NEURON sections; de-duplicate via a set because
# several nodes can map to the same section.
axon_secs = list(set(cmp.get_node_section(axon_nodes)))

# Add NEURON's built-in "hh" mechanism to the axonal sections only.
cmp.insert("hh", subset=axon_secs)
# Geodesic (along-the-arbor) distance from the soma to every node. With a
# single source node the matrix has a single row, which is sorted ascending.
dists = navis.geodesic_matrix(n, from_=n.soma).iloc[0].sort_values()

# Preview the ten nodes closest to the soma.
dists.head(10)
400     0.000000
398     2.574118
397     2.775527
399     3.277038
358     5.067570
357    14.071592
291    20.047707
290    27.316755
254    32.438854
256    37.647346
Name: 400, dtype: float32
# Landmark nodes, taken from the soma-distance ordering in `dists`:
#   dend_base - the dendritic node closest to the soma
#   axo_tip   - the axonal node farthest from the soma (along the arbor)
dend = n.nodes.loc[n.nodes.compartment == "dendrite", "node_id"].values
dend_base = dists.index[dists.index.isin(dend)][0]
print(f"Node at the base of the dendrites: {dend_base}")

axo = n.nodes.loc[n.nodes.compartment == "axon", "node_id"].values
axo_tip = dists.index[dists.index.isin(axo)][-1]
print(f"Node at the tip of the axon: {axo_tip}")
Node at the base of the dendrites: 131
Node at the tip of the axon: 17
# Add voltage recordings
# at the dendritic node closest to the soma, the most distal axonal node,
# and the soma itself; the labels name the traces in the result plots.
cmp.add_voltage_record(dend_base, label="dendrite_base")
cmp.add_voltage_record(axo_tip, label="axon_tip")
cmp.add_voltage_record(n.soma, label="soma")
# Keep only the postsynapses that sit on the dendritic compartment, then preview.
post = n.postsynapses.loc[lambda syn: syn["compartment"] == "dendrite"]
post.head()
connector_id x y z partner_id type node_id compartment
6459 6459 253.904 287.936 104.076 576460752650983024 post 135 dendrite
6460 6460 253.152 289.888 105.210 576460752769017234 post 136 dendrite
6461 6461 251.328 284.928 106.806 576460752705584683 post 119 dendrite
6462 6462 250.960 285.920 102.774 576460752688729301 post 99 dendrite
6463 6463 252.640 286.624 101.724 576460752688729301 post 132 dendrite
# Five stimulation rounds, 200 ms apart and starting at t = 50 ms. Round i
# opens the first i * 5 unique dendritic synapse nodes, i.e. 0, 5, 10, 15 and
# 20 synapses - so the round at 50 ms activates no synapses at all.
for i in range(5):
    start = 50 + i * 200  # onset of this round (ms)
    n_syn = i * 5  # number of synapse nodes opened in this round
    cmp.add_synaptic_current(
        where=post.node_id.unique()[:n_syn],
        start=start,
        max_syn_cond=0.1,
        rev_pot=-10,
    )
# Now we can run our simulation for 1000ms
# (this is equivalent to neuron.h.finitialize + neuron.h.continuerun)
# v_init=-60 starts at the same potential as the leak reversal set above.
cmp.run_simulation(1000, v_init=-60)
# Plot the results
# (the recorded traces labelled dendrite_base, axon_tip and soma)
axes = cmp.plot_results()
../_images/0d507a727e75b43d963981a637b6c1776a0275b5a0af8e8506cf1f09699bf37e.png
# Plot again and zoom in on one spike
# (240-280 ms brackets the stimulation round that starts at 50 + 1 * 200 = 250 ms)
axes = cmp.plot_results()
axes[0].set_xlim(240, 280)
(240.0, 280.0)
../_images/3988852ed51cdfbcf68d9b5a6ca3879754850122f6a75bc496039dd55b058e9e.png
# First we need to reset our model (by re-assigning `cmp` the old model will be garbage-collected)
cmp = nrn.CompartmentModel(n, res=10)

# Set properties and mechanisms
# (same passive parameters and axonal HH channels as the first model)
cmp.Ra, cmp.cm = 266.1, 0.8
cmp.insert("pas", g=1 / 20800, e=-60)
axon_secs = list(set(cmp.get_node_section(axon_nodes)))
cmp.insert("hh", subset=axon_secs)

# Add recording
cmp.add_voltage_record(dend_base, label="dendrite_base")
cmp.add_voltage_record(axo_tip, label="axon_tip")
cmp.add_voltage_record(n.soma, label="soma")

# Also add a spike counter at the axon
cmp.add_spike_detector(axo_tip, label="axon_tip")

# Now add noisy presynaptic input to, say, 20 dendritic postsynapses
# (the first 20 unique nodes carrying a dendritic "post" connector)
post = n.connectors[
    (n.connectors.compartment == "dendrite") & (n.connectors.type == "post")
]
cmp.add_synaptic_input(
    post.node_id.unique()[0:20],
    spike_no=20,  # produce 20 presynaptic spikes
    spike_int=50,  # with an average interval of 50ms: 20 * 50ms = over 1s
    spike_noise=1,  # very noisy!
    cn_weight=0.04,  # presumably the synaptic connection weight - confirm in navis docs
)
# Run for 1s
cmp.run_simulation(1000, v_init=-60)
# Plot results
axes = cmp.plot_results()
# The second panel presumably shows the spike detector output, so label it as a rate.
axes[1].set_ylabel("spikes [Hz]")
Text(0, 0.5, 'spikes [Hz]')
../_images/e6936a1438c298add709519b34b477aaec7fa37f6e184b3e83254502ec090593.png
# First we need to reset our model (by re-assigning `cmp` the old model will be garbage-collected)
cmp = nrn.CompartmentModel(n, res=10)

# Set properties and mechanisms
# (same passive parameters and axonal HH channels as the first model)
cmp.Ra, cmp.cm = 266.1, 0.8
cmp.insert("pas", g=1 / 20800, e=-60)
axon_secs = list(set(cmp.get_node_section(axon_nodes)))
cmp.insert("hh", subset=axon_secs)

# Add recording
# at every node of the skeleton, so voltages can later be mapped back onto
# the whole morphology
cmp.add_voltage_record(n.nodes.node_id.values)

# Get dendritic postsynapses
post = n.postsynapses[n.postsynapses.compartment == "dendrite"]

# Add synaptic input to the first 10 unique dendritic nodes carrying
# postsynapses, starting at 2 ms
cmp.add_synaptic_current(where=post.node_id.unique()[0:10], start=2, max_syn_cond=0.1, rev_pot=-10)

# This is equivalent to neuron.h.finitialize + neuron.h.continuerun
# Simulate 100 ms starting from -60 mV.
cmp.run_simulation(100, v_init=-60)
import numpy as np
import pandas as pd

# Stack the per-node voltage vectors into a table: one row per recorded node
# (indexed by node ID), one column per recorded sample.
records = pd.DataFrame(
    np.vstack([vec.as_numpy() for vec in cmp.records["v"].values()]),
    index=list(cmp.records["v"].keys()),
)

# Put the rows in the same order as the skeleton's node table.
records = records.reindex(n.nodes.node_id)

records.head()
0 1 2 3 4 5 6 7 8 9 ... 3991 3992 3993 3994 3995 3996 3997 3998 3999 4000
node_id
0 -60.0 -60.265901 -60.522424 -60.770669 -61.011606 -61.246079 -61.474742 -61.698152 -61.916781 -62.131019 ... -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368
3 -60.0 -60.265901 -60.522424 -60.770669 -61.011606 -61.246079 -61.474742 -61.698152 -61.916781 -62.131019 ... -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368
1 -60.0 -60.265901 -60.522424 -60.770669 -61.011606 -61.246079 -61.474742 -61.698152 -61.916781 -62.131019 ... -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368
4 -60.0 -60.265901 -60.522424 -60.770669 -61.011606 -61.246079 -61.474742 -61.698152 -61.916781 -62.131019 ... -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368
6 -60.0 -60.265901 -60.522424 -60.770669 -61.011606 -61.246079 -61.474742 -61.698152 -61.916781 -62.131019 ... -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368 -64.778368

5 rows × 4001 columns

# plot the data
import matplotlib.pyplot as plt

# The columns of `records` are sample indices, not times. The 100 ms run
# produced 4001 samples (columns 0..4000), i.e. one sample every 0.025 ms,
# so convert the indices to ms to match the x-axis label.
dt = 0.025  # ms per recorded sample
time_ms = records.columns.to_numpy() * dt

for idx, row in records.iterrows():
    # offset each trace by its node ID for visibility
    plt.plot(time_ms, row.values + idx * 6, color="k", lw=0.5)
plt.xlabel("Time (ms)")
plt.ylabel("Node ID (offset for visibility)")
plt.title("Voltage traces at each node")
plt.show()
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[1], line 4
      1 # plot the data
      2 import matplotlib.pyplot as plt
----> 4 for idx, row in records.iterrows():
      5     plt.plot(row.values + idx * 6, color="k", lw=0.5)  # offset each trace for visibility
      6 plt.xlabel("Time (ms)")

NameError: name 'records' is not defined
# Colour the skeleton by the membrane voltage at a single point in time.
dt = 0.025  # ms per recorded sample (100 ms were recorded as 4001 samples)
t_snapshot = 5  # ms after simulation start
vmin, vmax = -70, 10  # colour range in mV

# `records` columns are sample indices and its rows follow n.nodes order.
# Use round() rather than int(): truncating a float quotient can land one
# sample short (e.g. int(0.3 / 0.1) == 2).
n.nodes['v'] = records.loc[:, round(t_snapshot / dt)].values

# Plot
fig, ax = navis.plot2d(
    n,
    method="2d",
    color_by="v",  # color by the voltage column
    palette="viridis",
    vmin=vmin,
    vmax=vmax,
    view=('x', '-y'),
)

# Manually add a colorbar
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
sm = ScalarMappable(norm=plt.Normalize(vmin=vmin, vmax=vmax), cmap='viridis')
_ = fig.colorbar(sm, ax=ax, fraction=0.075, shrink=0.5, label="V [mV]")
../_images/7e82ea6d8431b94b62dc52a2e54dfbc90e9c75123c5befeb235f09e5142a5729.png
# Animate how the membrane voltage spreads through the neuron over time.
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib.cm import ScalarMappable

dt = 0.025  # ms per recorded sample (100 ms were recorded as 4001 samples)
vmin, vmax = -70, 10  # colour range in mV, shared by the mesh and the colorbar

# Convert our skeleton to a mesh for nicer visualization
mesh = navis.conversion.tree2meshneuron(n, warn_missing_radii=False)

# Plot the neuron
fig, ax = navis.plot2d(mesh, method='2d', color='k', view=('x', '-y'))

sm = ScalarMappable(norm=plt.Normalize(vmin=vmin, vmax=vmax), cmap='viridis')
_ = fig.colorbar(sm, ax=ax, fraction=0.075, shrink=0.5, label="V [mV]")

# Add a text in the top-left corner for the timestamp
# (axes coordinates x=0.02, y=0.95, left/top aligned)
t = ax.text(0.02, 0.95, 'ms', ha='left', va='top', transform=ax.transAxes, color='r')

# Get the collection representing our neuron and give it the same colour
# scale as the colorbar
c = ax.collections[0]
c.set_cmap('viridis')
c.set_norm(plt.Normalize(vmin=vmin, vmax=vmax))


# This function updates the voltages according to the frame
def animate(i):
    """Colour the mesh by the voltages recorded at sample ``i``.

    Returns the modified artists, as FuncAnimation requires with blit=True.
    """
    # We need to map the voltages at individual nodes to faces in the mesh
    # First nodes to vertices
    # NOTE(review): assumes mesh.vertex_map gives, per vertex, the positional
    # index of its skeleton node - matching the row order of `records` after
    # it was reindexed to n.nodes. Confirm in the navis docs.
    vert_voltage = records[i].values[mesh.vertex_map]
    # Then vertices to faces: average the voltages of each face's vertices
    face_voltage = vert_voltage[mesh.faces].mean(axis=1)
    # Set the values
    c.set_array(face_voltage)
    # Also update the timestamp
    t.set_text(f'{i * dt:.2f} ms')
    return (c, t)


# 400 frames at one recorded sample each cover the first 400 * dt = 10 ms.
ani = animation.FuncAnimation(fig, animate, interval=40, blit=True, repeat=True, frames=400)
HTML(ani.to_jshtml())
2025-10-07 18:18:05 - INFO - Animation.save using <class 'matplotlib.animation.HTMLWriter'>
../_images/1b85f90e85bf08c0ce80ce774a15d60b610f1291a21f22bbbe455e9a2b545ef7.png